import torch
import torch.nn as nn
import torch.nn.functional as F

# Module-wide default device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Attention(nn.Module):
    """Attention decoder that turns encoder features into per-step class scores.

    At every decoding step the previous character is embedded, fed together
    with the encoder features to an ``AttentionCell`` and the resulting hidden
    state is mapped to class scores by ``generator``.

    Args:
        input_size: channel width of the encoder features ``batch_H``.
        hidden_size: size of the decoder LSTM hidden state.
        num_class: number of character classes (also the embedding vocabulary).
        fc: module mapping a hidden state (``[..., hidden_size]``) to class
            scores; it is supplied by the caller and stored as ``generator``.
        num_char_embeddings: dimensionality of the character embedding.

    After each ``forward`` call two attributes are refreshed:
        ``alpha_history``: attention weights, ``[batch x num_steps x num_encoder_steps]``.
        ``context_history``: context vectors, ``[batch x num_steps x encoder_channels]``.
    """

    def __init__(self, input_size, hidden_size, num_class, fc, num_char_embeddings=256):
        super(Attention, self).__init__()
        self.attention_cell = AttentionCell(
            input_size, hidden_size, num_char_embeddings
        )
        self.hidden_size = hidden_size
        self.num_class = num_class
        self.generator = fc
        self.num_char_embeddings = num_char_embeddings
        self.char_embeddings = nn.Embedding(num_class, num_char_embeddings)

    def _char_to_onehot(self, input_char, onehot_dim=38):
        """Return a float one-hot encoding ``[batch x onehot_dim]`` of ``input_char``.

        ``input_char`` is a 1-D integer tensor of class indices. The result is
        created on the same device as ``input_char`` so that ``scatter_`` never
        mixes devices.
        """
        input_char = input_char.unsqueeze(1)
        one_hot = torch.zeros(
            input_char.size(0), onehot_dim, dtype=torch.float, device=input_char.device
        )
        return one_hot.scatter_(1, input_char, 1)

    def minmax(self, a):
        """Rescale ``a`` linearly so that its minimum maps to 0 and its maximum to 1.

        NOTE: when every element of ``a`` is equal the range is zero and the
        result contains NaN (0 / 0); callers must handle that case themselves.
        """
        min_a = torch.min(a)
        max_a = torch.max(a)
        return (a - min_a) / (max_a - min_a)

    def cut_unknown(self, index):
        """Replace every index ``>= num_class`` with 0 so embedding lookups stay in range."""
        # masked_fill is equivalent to torch.where(cond, 0, index) but does not
        # rely on scalar-argument support in torch.where.
        return index.masked_fill(index >= self.num_class, 0)

    def forward(self, batch_H, text, is_train=True, batch_max_length=25):
        """
        input:
            batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels]
            text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [SOS] token. text[:, 0] = [SOS].
                   At inference only the [SOS] entry is read; a 1-D tensor whose first element is [SOS] is accepted too.
            is_train : True -> teacher forcing with `text`; False -> greedy decoding from [SOS].
            batch_max_length : maximum number of characters to decode (one extra step is added for [EOS]).
        output: generator scores at each step [batch_size x num_steps x num_class]
        """
        batch_size = batch_H.size(0)
        num_steps = batch_max_length + 1  # +1 for [EOS] at end of sentence.
        # Allocate on the device of the incoming features rather than a global
        # default, so the module works wherever the model has been moved.
        dev = batch_H.device

        hidden = (
            torch.zeros(batch_size, self.hidden_size, dtype=torch.float, device=dev),
            torch.zeros(batch_size, self.hidden_size, dtype=torch.float, device=dev),
        )

        # Context vectors live in encoder-channel space, so the buffer width
        # follows batch_H (it only coincides with hidden_size in some configs).
        self.context_history = torch.zeros(
            batch_size, num_steps, batch_H.size(2), dtype=torch.float, device=dev
        )
        self.alpha_history = []

        if is_train:
            output_hiddens = torch.zeros(
                batch_size, num_steps, self.hidden_size, dtype=torch.float, device=dev
            )
            for i in range(num_steps):
                char_embeddings = self.char_embeddings(self.cut_unknown(text[:, i]))
                # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_embeddings : f(y_{t-1})
                hidden, alpha, context = self.attention_cell(hidden, batch_H, char_embeddings)

                output_hiddens[:, i, :] = hidden[0]  # LSTM hidden index (0: hidden, 1: Cell)
                self.alpha_history.append(alpha)
                self.context_history[:, i, :] = context
            probs = self.generator(output_hiddens)
        else:
            # Start every sequence from its [SOS] token. For 2-D text this is
            # the first column (one token per sample); for 1-D text it is the
            # first element, broadcast over the batch.
            start_tokens = text[:, 0] if text.dim() > 1 else text[0]
            targets = start_tokens.expand(batch_size)
            probs = torch.zeros(
                batch_size, num_steps, self.num_class, dtype=torch.float, device=dev
            )

            for i in range(num_steps):
                char_embeddings = self.char_embeddings(self.cut_unknown(targets))
                hidden, alpha, context = self.attention_cell(hidden, batch_H, char_embeddings)
                probs_step = self.generator(hidden[0])
                probs[:, i, :] = probs_step
                # Greedy decoding: the best-scoring class is the next input.
                _, targets = probs_step.max(1)

                self.alpha_history.append(alpha)
                self.context_history[:, i, :] = context

        # Each alpha is [batch x num_encoder_steps x 1]; join along the last
        # axis and swap so the result is batch_size x num_steps x num_encoder_steps.
        self.alpha_history = torch.cat(self.alpha_history, -1).permute(0, 2, 1)
        return probs  # batch_size x num_steps x num_class


class AttentionCell(nn.Module):
    """One decoding step: additive attention over encoder features plus an LSTM cell update.

    Args:
        input_size: channel width of the encoder features.
        hidden_size: size of the LSTM hidden/cell state.
        num_embeddings: width of the character embedding concatenated to the context.
    """

    def __init__(self, input_size, hidden_size, num_embeddings):
        super().__init__()
        # Encoder-side projection carries no bias; the decoder-side projection
        # below supplies the single bias term of the additive score.
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)
        self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_embeddings):
        """Run a single attention + LSTM step.

        Args:
            prev_hidden: ``(h, c)`` tuple of the previous LSTM state.
            batch_H: encoder features, ``[batch_size x num_encoder_step x num_channel]``.
            char_embeddings: embedding of the previous character, ``[batch_size x num_embedding]``.

        Returns:
            ``(cur_hidden, alpha, context)`` where ``cur_hidden`` is the new
            ``(h, c)`` tuple, ``alpha`` the attention weights
            ``[batch_size x num_encoder_step x 1]`` and ``context`` the weighted
            feature sum ``[batch_size x num_channel]``.
        """
        prev_h = prev_hidden[0]

        # Project both sides into hidden_size; the decoder term gets a step
        # axis of length 1 so it broadcasts over all encoder steps.
        encoder_proj = self.i2h(batch_H)
        decoder_proj = self.h2h(prev_h).unsqueeze(1)
        energies = self.score(torch.tanh(encoder_proj + decoder_proj))

        # Normalise over the encoder-step axis.
        alpha = torch.softmax(energies, dim=1)

        # [batch x 1 x steps] @ [batch x steps x channel] -> [batch x channel]
        weights_row = alpha.transpose(1, 2)
        context = torch.bmm(weights_row, batch_H).squeeze(1)

        # LSTM input is the context joined with the previous-character embedding.
        rnn_input = torch.cat((context, char_embeddings), dim=1)
        next_hidden = self.rnn(rnn_input, prev_hidden)
        return next_hidden, alpha, context
